import html
import json
import math
import os
import random
import re
import sys
from collections import defaultdict
from functools import lru_cache

import numpy as np
import torch
import transformers
from beartype.typing import Any, TypedDict, Union
from bs4 import BeautifulSoup, Tag, Comment
from lxml import html as lxml_html
from transformers import AutoModelForCausalLM, AutoTokenizer


# Hugging Face access token passed to every `from_pretrained` call in this file.
# Read from the HF_TOKEN environment variable so a credential never has to be
# written into the source; falls back to the previous empty-string default.
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# Set of element tag names: standard HTML/SVG tags plus site-specific custom
# elements (lightning-*, kat-*, lyte-*, d2l-*, ...).
# NOTE(review): nothing in the visible part of this file reads this set;
# presumably these are the tags considered interactive/selectable by code
# elsewhere - confirm before relying on it.
interactive_elements_selectors  = {
    "body", 'a', 'mark', 'div', 'p', 'svg', 'td', 'path', 'button', 'span', 'input', 'select', 'label', 'use', 'lightning-primitive-icon', 'li', 'h3', 'textarea', 'var', 'canvas', 'lightning-base-combobox-item', 'strong', 'h2', 'img', 'i18n-string', 'i', 'h4', 'footer', 'b', 'pre', 'rect', 'ul', 'cf-column', 'h5', 'image', 'circle', 'kat-badge', 'h1', 'sg-widget-container', 'lightning-input', 'dd', 'small', 'mfe-header-pharos-dropdown-menu-item', 'search-results-vue-pharos-text-input', 'search-results-vue-pharos-checkbox', 'search-results-vue-pharos-select', 'fieldset', 'search-results-vue-pharos-button', 'mfe-citation-pharos-button', 'dt', 'lightning-formatted-number', 'form', 'font', 'article', 'slot', 'lightning-formatted-date-time', 'ui5-link', 'lightning-icon', 'kat-input', 'kat-icon', 'kat-link', 'mat-label', 'section', 'option', 'bc-autocomplete', 'video', 'application', 'lyte-drop-button', 'lyte-text', 'lyte-exptable', 'lyte-yield', 'polygon', 'main', 'mat-icon', 'flutter-view', 'lightning-formatted-name', 'lyte-icon', 'lyte-drop-item', 'abbr', 'lightning-base-combobox-formatted-text', 'mwc-button', 'colab-run-button', 'md-icon', 'lyte-exptable-td', 'address', 'd2l-menu-item-link', 'd2l-button-icon', 'd2l-input-date-time', 'th', 'table', 'tr', 'yc-i18n', 'ol', 'u', 'lyte-menu-item', 'flowruntime-lwc-body', 'lightning-formatted-text', 'iframe', 'em', 'c-wiz', 'trix-editor', 'ab-app-nav-link', 'ab-nav-x-link', 'h6', 'flexipage-column2', 'bdi', 'xweb-shellbar', 'artdeco-dropdown-item', 'seamless-integration', 'c-icon', 'header', 'records-highlights-details-item', 'flowruntime-screen-field', 'lightning-primitive-file-droppable-zone', 'd2l-my-courses', 'd2l-cs-app', 'bsp-line', 'flexipage-record-home-scrollable-column', 'kat-textarea', 'kat-button', 'o-textarea', 'flowruntime-lwc-header'
}

# Tag names that Processor.convert_html keeps as elements in the cleaned DOM.
# Any tag whose name is neither here nor in code_elements_to_decompose is
# unwrapped (the element is removed, its children stay in place).
valid_tags = {
	'div', 'body', 'span', 'svg', 'input', 'img', 'p', 'a', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'b', 'i', 'u', 'strong', 'em', 'abbr', 'cite', 'q', 'code', 'ins', 'var', 'area', 'ul', 'li', 'ol', 'dl', 'dt', 'dd', 'form', 'button', 'col', 'textarea', 'path', 'lightning-primitive-icon', 'select', 'label', 'td', 'canvas', 'circle', 'i18n-string', 'table', 'tr', 'image', 'footer', 'use', 'option', 'rect', 'mark', 'section', 'th', 'polygon', 'aside', 'main', 'header', 'pre', 'figure'
}
# Tag names that Processor.convert_html removes together with their contents
# (via Tag.decompose()).
code_elements_to_decompose = {
	'style', 'script'
}

# Attribute names (compared lower-cased) that Processor.convert_html keeps on
# a tag; attributes not listed here are deleted, except those whose name
# contains "require", "expand" or "selected", which are always kept.
# "node" is the synthetic per-tag index that convert_html itself assigns.
salient_attributes = {
	"alt",
	"aria-role",
	"aria-label",
	"option_selected",
	"placeholder",
	"role",
	"type",
	"node",
	"desc",
	"label",
	"input",
	"name",
	"title",
	"text",
	"value",
	"href",
	"expanded",
	"required",
	"selected",
	"id",
	"class"
}

# A subset of salient_attributes (every name below also appears there).
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably used by code elsewhere or left over - confirm.
sub_salient_attributes = {
    "desc",
    "label",
    "input",
    "name",
    "title",
    "text",
    "value",
}


# Raise the interpreter recursion limit to 16000.
# NOTE(review): presumably needed because very deep DOM trees are walked and
# serialised below - confirm which call actually requires it.
sys.setrecursionlimit(16000)


# Module-level tokenizer; token_ratio() looks this name up at call time.
# NOTE(review): the name `tokenizer` is rebound further down in this file to
# the tokenizer of the evaluated model, and token_ratio() is only called after
# that, so this Mistral tokenizer appears to be loaded (and downloaded, using
# HF_TOKEN) without ever being used - confirm and consider removing.
tokenizer = AutoTokenizer.from_pretrained(
	"mistralai/Mistral-7B-Instruct-v0.3",
	cache_dir="./",
	token=HF_TOKEN
)


@lru_cache(maxsize=4096)
def token_ratio(window):
    """Return the number of characters per token for the string ``window``.

    The string is tokenized with the module-level ``tokenizer`` (no special
    tokens added). A small epsilon is added to the token count so an input
    that produces zero tokens does not divide by zero. Results are memoised
    for the 4096 most recently used strings.
    """
    encoding = tokenizer(window, add_special_tokens=False)
    token_count = len(encoding["input_ids"])
    char_count = float(len(window))
    return char_count / (token_count + 1e-5)

def clean_dom(domstr):
    """Normalise a raw HTML string to single-spaced, ASCII-only text.

    Steps, in order:
      1. HTML entities are unescaped.
      2. Backslash escape sequences (``\\n``, ``\\u2013``, ...) are decoded via
         the ``unicode_escape`` codec. This is best-effort: if the text cannot
         be decoded (e.g. it ends with a lone backslash) it is left as is.
      3. A few typographic characters and doubly-escaped entities are mapped
         to ASCII equivalents.
      4. Every remaining run of non-ASCII characters, and every run of control
         characters other than tab/LF/CR, is replaced by a space.
      5. Whitespace runs are collapsed to a single space.

    NOTE(review): when step 2 succeeds, non-ASCII characters have been
    re-interpreted byte-by-byte, so the typographic replacements in step 3
    only match when step 2 was skipped; otherwise those characters are blanked
    by step 4. Behaviour is kept as is because downstream model inputs depend
    on it.

    Args:
        domstr: HTML source as ``str``.

    Returns:
        The cleaned string.
    """
    domstr = html.unescape(domstr)
    try:
        domstr = bytes(domstr, "utf-8").decode("unicode_escape")
    except UnicodeError:
        # Only encoding/decoding failures are expected here; anything else
        # (including KeyboardInterrupt) must not be swallowed.
        pass

    # Order matters for the entity entries: "&amp;lt;" must first become
    # "&lt;" and only then "<".
    replacements = (
        ("–", '-'),
        ("•", '-'),
        ("’", '\''),
        ("‹", '<'),
        ("×", '*'),
        ("·", '.'),
        ("”", "\""),
        ("&amp;", "&"),
        ("&lt;", "<"),
        ("&gt;", ">"),
        ("＋", '+'),
    )
    for old, new in replacements:
        domstr = domstr.replace(old, new)

    # Blank out anything that is not 7-bit ASCII.
    domstr = re.sub(r'[^\x00-\x7F]+', ' ', domstr)
    # Blank out characters outside the XML-legal ranges (i.e. the remaining
    # control characters except tab, LF and CR).
    domstr = re.sub(u'[^\u0020-\uD7FF\u0009\u000A\u000D\uE000-\uFFFD\U00010000-\U0010FFFF]+', ' ', domstr)
    domstr = re.sub(r"\s+", " ", domstr)
    # Strip private-use-area code points.
    domstr = re.sub(r'[\ue000-\uf8ff]', '', domstr)
    return domstr


class Processor():
    """Turns raw page HTML into a compact, node-indexed DOM string.

    ``convert_html`` tags every element with a synthetic ``node`` index and
    prunes tags/attributes; ``basic_clean`` normalises the resulting text.
    """

    # Ordered (old, new) substitutions applied by basic_clean. Order matters
    # for the entity entries: "&amp;lt;" must become "&lt;" before "<".
    _REPLACEMENTS = (
        ("–", '-'),
        ("•", '-'),
        ("’", '\''),
        ("‹", '<'),
        ("×", '*'),
        ("·", '.'),
        ("”", "\""),
        ("&amp;", "&"),
        ("&lt;", "<"),
        ("&gt;", ">"),
        ("＋", '+'),
    )

    def __init__(
        self,
    ):
        self.observation_type = "dom"

    def collect_tags(self, tag, tags):
        """Append ``tag`` and every Tag beneath it to ``tags`` in document order.

        Non-Tag nodes (text, comments) are skipped. The walk is pre-order and
        uses an explicit stack, so it does not depend on the interpreter
        recursion limit for deeply nested documents.

        Args:
            tag: root node to start from (a bs4 ``Tag``; anything else is ignored).
            tags: list that collected tags are appended to (mutated in place).
        """
        pending = [tag]
        while pending:
            node = pending.pop()
            if isinstance(node, Tag):
                tags.append(node)
                # Push children reversed so the first child is popped first.
                pending.extend(reversed(list(node.children)))

    def convert_html(self, html_content):
        """Index and prune an HTML document.

        If ``html_content`` does not contain ``"<body"`` it is returned
        untouched with ``None`` maps. Otherwise:

          * every tag gets a ``node`` attribute (tags are numbered in reverse
            document order) and its ``backend_node_id`` attribute, if any, is
            moved into the returned map;
          * comments are removed; ``script``/``style`` elements are dropped;
            tags not in ``valid_tags`` are unwrapped;
          * if more than 20 ``<option>`` elements have purely numeric text,
            all such options are dropped;
          * attributes are pruned down to ``salient_attributes`` (see the loop
            below for the exact rules).

        Args:
            html_content: HTML source as ``str``.

        Returns:
            ``(full_html_doc, full_nmap, cleaned_html_doc, full_nmap)`` where the
            docs are whitespace-collapsed strings (before/after pruning) and
            ``full_nmap`` maps ``str(node index)`` -> ``backend_node_id``. The
            same map object is returned in both map positions.
        """
        if "<body" not in html_content:
            return html_content, None, html_content, None

        soup = BeautifulSoup(html_content, "html.parser")
        all_tags = []
        self.collect_tags(soup, all_tags)
        full_nmap = {}
        for i, tag in enumerate(all_tags[::-1]):
            tag["node"] = int(i)
            try:
                full_nmap[str(i)] = tag["backend_node_id"]
                del tag["backend_node_id"]
            except KeyError:
                # Tag carries no backend_node_id; it simply gets no map entry.
                pass

        comments = soup.find_all(string=lambda text: isinstance(text, Comment))
        for comment in comments:
            comment.extract()

        full_html_doc = soup.prettify()
        full_html_doc = re.sub(r"\s+", " ", full_html_doc)

        # all_tags[0] is the BeautifulSoup root object itself, hence [1:].
        num_op_tag = 0
        for tag in all_tags[1:]:
            if tag.name in code_elements_to_decompose:
                tag.decompose()
            elif tag.name not in valid_tags:
                tag.unwrap()
            elif tag.name == "option" and tag.text.isdigit():
                num_op_tag += 1
        if num_op_tag > 20:
            for tag in all_tags[1:]:
                if tag.name == "option" and tag.text.isdigit():
                    tag.decompose()

        max_len = 32
        for tag in all_tags:
            # Skips tags whose `attrs` lookup yields None (this guard is kept
            # exactly as in the original; it is hit for decomposed tags).
            if tag.attrs is None:
                continue

            for attr in list(tag.attrs):
                attr_name = attr.lower()

                # State-like attributes are always kept.
                if "require" in attr_name or "expand" in attr_name or "selected" in attr_name:
                    continue

                # Long values that tokenize poorly (< 2 chars per token) are
                # dropped.
                value_text = str(tag[attr])
                if len(value_text) > max_len and token_ratio(value_text) < 2:
                    del tag[attr]
                    continue

                if "script" in attr_name:
                    del tag[attr]
                    continue
                if attr_name not in salient_attributes:
                    del tag[attr]
                    continue
                elif (tag[attr] == "" or tag[attr] == "none"):
                    del tag[attr]
                    continue
                # The original guarded this with `attr in tag`, which in bs4
                # tests the tag's *children*, not its attributes, so the
                # branch never ran. Reaching this point means the attribute is
                # still present.
                if tag.name == "iframe" and attr != "node":
                    del tag[attr]

        cleaned_html_doc = soup.prettify()
        cleaned_html_doc = re.sub(r"\s+", " ", cleaned_html_doc)

        return full_html_doc, full_nmap, cleaned_html_doc, full_nmap

    def basic_clean(self, input_str, is_acc=False):
        """Normalise text to ASCII and tidy whitespace.

        Typographic characters and a few HTML entities are mapped to ASCII,
        remaining non-ASCII characters are removed, runs of spaces are
        collapsed, and lines ending in ``StaticText ''`` or ``LineBreak '\\n'``
        are dropped.

        Args:
            input_str: text to clean.
            is_acc: when true, newlines are preserved; otherwise all
                whitespace runs (including newlines) collapse to one space.

        Returns:
            The cleaned string.
        """
        for old, new in self._REPLACEMENTS:
            input_str = input_str.replace(old, new)
        input_str = re.sub(r'[^\x00-\x7F]+', '', input_str)
        input_str = re.sub(u'[^\u0020-\uD7FF\u0009\u000A\u000D\uE000-\uFFFD\U00010000-\U0010FFFF]+', '', input_str)
        input_str = re.sub(r'[\ue000-\uf8ff]', '', input_str)
        input_str = re.sub(r"[ ]+", " ", input_str)
        input_str = re.sub(r"\n([^\n]+)StaticText \'\'\n", "\n", input_str)
        input_str = re.sub(r"\n([^\n]+)LineBreak \'\n\'\n", "\n", input_str)
        if not is_acc:
            input_str = re.sub(r"\s+", " ", input_str)

        return input_str

def print_without_children(element):
    """Render ``element`` as ``<tag attrs>text</tag>``, ignoring child elements.

    Attributes are emitted in ``element.attrib`` order as ``name="value"``
    (values are not escaped). The element's own leading text is included,
    stripped of surrounding whitespace; whitespace-only or missing text
    contributes nothing.
    """
    attr_part = "".join(
        f' {key}="{val}"' for key, val in element.attrib.items()
    )
    own_text = (element.text or "").strip()
    return f"<{element.tag}{attr_part}>{own_text}</{element.tag}>"

def eval(model, tokenizer, raw_inp_pre, prevact, samples):
    """Sample next-step predictions and majority-vote the target node id.

    For every DOM chunk in ``samples`` a chat prompt is built from
    ``raw_inp_pre``, the chunk and ``prevact``; the model is then sampled five
    times (RNG seeds are reset to 1 before each chunk). Each generation that
    mentions ``"Node"`` casts one vote for the node id parsed from it: first
    the text after ``"Node: "`` up to the newline is tried as an integer, and
    if that fails the value of the first ``node="..."`` attribute is used.

    NOTE: this function shadows the builtin ``eval``; the name is kept because
    callers use it.

    Args:
        model: causal LM exposing ``device`` and ``generate``.
        tokenizer: matching tokenizer exposing ``apply_chat_template``,
            ``__call__``, ``batch_decode`` and ``eos_token_id``.
        raw_inp_pre: prompt prefix (objective / URL lines).
        prevact: text of the steps taken so far, appended after the guide header.
        samples: list of DOM strings to try.

    Returns:
        ``(node_id, generated_text)`` for the most-voted node (ties go to the
        id that was first seen latest), where ``generated_text`` is the first
        generation that voted for it; or ``(None, None)`` as soon as a prompt
        exceeds 40000 tokens.

    Raises:
        IndexError: if no generation produced a vote.
        AttributeError: if a generation mentions ``"Node"`` but lacks the
            expected ``"Node: ...\\n"`` / ``node="..."`` structure.
        ValueError: if the fallback ``node="..."`` value is not an integer.
    """
    max_tries = 5
    seed = 1
    votes = defaultdict(int)
    texts_by_node = defaultdict(list)

    for sample in samples:
        raw_inp = raw_inp_pre + "Observation: " + sample + "\nStep-by-step guide:\n" + prevact
        # Re-seed before every chunk so sampling is reproducible per chunk.
        torch.cuda.empty_cache()
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.cuda.manual_seed_all(seed)
        print(raw_inp)

        messages = [
            {"role": "system", "content": PREPEND},
            {"role": "user", "content": raw_inp},
        ]
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        encoded = tokenizer(input_text, return_tensors="pt")
        model_inputs = {
            key: value.to(model.device).reshape(1, -1)
            for key, value in encoded.items()
        }

        input_len = model_inputs["input_ids"].shape[1]
        if input_len > 40000:
            return None, None

        for numtry in range(max_tries):
            generated_ids = model.generate(**model_inputs, max_new_tokens=200, do_sample=True, top_p=0.95, temperature=0.6, pad_token_id=tokenizer.eos_token_id)
            # Keep only the newly generated tokens (drop the prompt).
            new_token_ids = [generated_ids[0][input_len:]]
            generated_texts = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)[0]

            if "Node" in generated_texts:
                sidx = re.search("Node: ", generated_texts).end()
                tid = generated_texts[sidx:]
                eidx = re.search("\n", tid).start()
                tid = tid[:eidx]
                try:
                    int(tid)
                except ValueError:
                    # The "Node:" line is not a single integer; fall back to
                    # the first node="..." attribute in the generation.
                    s = re.search("node=\"", generated_texts).end()
                    newt = generated_texts[s:]
                    e = re.search("\"", newt).start()
                    tid = newt[:e]
                print("[Candidate ", numtry, "]", generated_texts)
                node_id = int(tid)
                votes[node_id] += 1
                texts_by_node[node_id].append(generated_texts)

    if not votes:
        # Same exception type the original raised implicitly when indexing
        # the empty vote list, but with an explanatory message.
        raise IndexError("no generated candidate contained a node id")

    # Stable ascending sort: among equal counts the last-inserted id wins.
    sortedgenerated = sorted(votes.items(), key=lambda x: x[1])
    maxid = sortedgenerated[-1][0]

    generated_texts = texts_by_node[maxid][0]
    print("[PRED]", generated_texts)

    return maxid, generated_texts

# Maximum context length (in tokens) the evaluation targets.
MAXLEN = 32768
# System prompt used for every chat request in eval().
PREPEND = "Help achieve the objective by generating the next step."
processor = Processor()

model_name_or_path = "Qwen/Qwen2-7B-Instruct"

config = transformers.AutoConfig.from_pretrained(
    model_name_or_path,
    cache_dir="/home/ubuntu/.cache/huggingface/hub",
    token=HF_TOKEN,
)

# Effective native context length = max_position_embeddings * existing RoPE
# scaling factor (1 when the config defines none).
orig_rope_scaling = getattr(config, "rope_scaling", None)
if orig_rope_scaling is None:
    orig_rope_scaling = {"factor": 1}
orig_rope_scaling_factor = orig_rope_scaling.get("factor", 1)
orig_ctx_len = getattr(config, "max_position_embeddings", None)
if orig_ctx_len:
    orig_ctx_len *= orig_rope_scaling_factor
    # Only stretch the context with linear RoPE scaling when MAXLEN exceeds
    # what the model supports natively. Nested under the guard above so a
    # config without max_position_embeddings no longer raises TypeError on
    # the comparison.
    if MAXLEN > orig_ctx_len:
        scaling_factor = float(math.ceil(MAXLEN / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}

# Load model and tokenizer
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    config=config,
    cache_dir="/home/ubuntu/.cache/huggingface/hub",
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
)
# Rebinds the module-level `tokenizer`; token_ratio() uses this one from here on.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path,
    cache_dir="/home/ubuntu/.cache/huggingface/hub",
    model_max_length=MAXLEN,
    padding_side="left",
    token=HF_TOKEN,
)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
model.eval()

# Apply the fine-tuned adapter, then overlay any extra weights saved in
# model.pt (strict=False: keys missing from the checkpoint are left as is).
output_dir = "/home/ubuntu/qwen_model/fix_acc/aug29_all_data_accessibility"
model.load_adapter(output_dir)
model.load_state_dict(torch.load(f"{output_dir}/model.pt"), strict=False)
model.to("cuda:0")

# Evaluation driver: for every task in the dataset, replay its actions step by
# step, ask the model for the next step on the cleaned DOM and score the
# predicted node against the ground-truth element.
#
# Outputs: progress/accuracy on stdout and "<filename>_response.json" with the
# per-step predictions (also checkpointed every 10th task).
#
# Scoring notes:
#   * steps whose ground-truth element cannot be located in the cleaned DOM
#     ("[NO GT]") or whose prompt is too long ("[Too Long]") are counted in
#     num_invalid and excluded from accuracy;
#   * steps that raise during prediction/scoring are also excluded from
#     accuracy and leave the task-level all_correct flag untouched.
filename = 'data/test_domain/test_domain_9'
# Load everything up front so the input file is closed before the long run.
with open(filename + '.json', 'r') as file:
    data = json.load(file)

response = {}
accs = []
stepwise_acc = {}
num_invalid = 0
allaccs = []
for wid, datapoint in enumerate(data):

    print("="*30, wid, len(data), num_invalid)
    url = datapoint["website"]
    if "." not in url:
        url = "https://www." + url + ".com"
    obj = datapoint["confirmed_task"]
    action_reprs = datapoint["action_reprs"]
    print(url, obj, action_reprs)
    response[wid] = {}

    actualsid = 1
    prevact_pred = ""
    prevact_gt = ""
    all_correct = 1

    # Periodic checkpoint of the responses collected so far.
    if (wid + 1) % 10 == 0:
        with open(filename + "_response.json", "w") as f:
            json.dump(response, f)

    for sid, step in enumerate(datapoint["actions"]):
        print("-"*20, sid)
        action = step["operation"]["op"]
        des = action.lower()
        if len(step["operation"]["value"]) > 0:
            des += " " + step["operation"]["value"]
        action_repr = action_reprs[sid]

        # Ground-truth ids. Always a list: the original assigned the bare
        # action_uid string here, so fneids[0] was its first character.
        if len(step["pos_candidates"]) == 0:
            fneids = [step['action_uid']]
        else:
            fneids = [target["backend_node_id"] for target in step["pos_candidates"]]

        print(step["pos_candidates"], action_repr, fneids)

        html_content = step["raw_html"]
        html_content = clean_dom(html_content)
        full_html_doc, full_nmap, cleaned_html_doc, cleaned_nmap = processor.convert_html(html_content)
        if cleaned_nmap is None:
            # convert_html found no <body>, so there is no node map and the
            # ground truth cannot be located.
            print("[NO GT]")
            num_invalid += 1
            continue
        cleaned_html_doc = processor.basic_clean(cleaned_html_doc)
        # Keep only <body>...</body> when both markers are present.
        body_open = re.search("<body", cleaned_html_doc)
        body_close = re.search("</body>", cleaned_html_doc)
        if body_open and body_close:
            cleaned_html_doc = cleaned_html_doc[body_open.start():body_close.end()]
        inv_cleaned_id_map = {v: k for k, v in cleaned_nmap.items()}

        if str(fneids[0]) not in inv_cleaned_id_map:
            print("[NO GT]")
            num_invalid += 1
            continue
        gtid = inv_cleaned_id_map[str(fneids[0])]
        # Match the whole attribute including the closing quote, so that
        # e.g. id 12 does not match node="123".
        gt_marker = f'node="{gtid}"'
        if gt_marker not in cleaned_html_doc:
            print("[NO GT]")
            num_invalid += 1
            continue
        try:
            cleaned_tree = lxml_html.fromstring(cleaned_html_doc)
            selected_element = cleaned_tree.xpath(f"//*[@node=\"{gtid}\"]")[0]
        except Exception:
            # Unparseable document or node not found by lxml.
            print("[NO GT]")
            num_invalid += 1
            continue

        # Tokenize once; the count drives the chunking strategy.
        num_tokens = len(tokenizer(cleaned_html_doc)["input_ids"])
        if num_tokens > MAXLEN * 3:
            # Very long page: split into equal character-length chunks and
            # let eval() vote across all of them.
            alllen = len(cleaned_html_doc)
            numchunk = num_tokens // MAXLEN
            chunk_chars = int(alllen / numchunk)
            samples = [
                cleaned_html_doc[cidx * chunk_chars:(cidx + 1) * chunk_chars]
                for cidx in range(numchunk)
            ]

        elif num_tokens > MAXLEN - 2500:
            # Moderately long page: pack tag-boundary segments into windows
            # of at most `windowsize` tokens, then keep only the window that
            # contains the ground-truth node.
            samples = []
            alllines = re.split("(</[a-z]+> <[a-z])", cleaned_html_doc)
            windowsize = MAXLEN - 2500

            # re.split with one capture group yields text, sep, text, ...,
            # text; glue each separator onto the text before it.
            lines = [text + sep for text, sep in zip(alllines[0::2], alllines[1::2])]
            lines.append(alllines[-1])

            sample = ""
            prev_remaining = ""
            pos = 0
            while pos < len(lines):
                segment = lines[pos]
                if segment[-4:-1] == "> <":
                    # Segment ends with "> <x": cut after ">" and carry the
                    # "<x" over to the start of the next segment.
                    line_to_add = segment[:-3]
                    remaining = segment[-2:]
                else:
                    line_to_add = segment
                    remaining = ""
                samplenew = sample + prev_remaining + line_to_add
                # A segment is always accepted into an empty window, even if
                # it alone exceeds the budget; the original looped forever in
                # that case (appending empty samples without advancing).
                if sample and len(tokenizer(samplenew)["input_ids"]) > windowsize:
                    samples.append(sample)
                    sample = ""
                else:
                    sample = samplenew
                    pos += 1
                    prev_remaining = remaining

            if len(sample) > 0:
                samples.append(sample)
            print("-"*15, "CHUNK DOM INTO", len(samples), "PIECES", "-"*15)

            # Default to the first window; the last window containing the
            # ground-truth node wins.
            sample = samples[0]
            for samplenew in samples:
                if gt_marker in samplenew:
                    sample = samplenew
                    print("[Manual select DOM]")
            samples = [sample]
        else:
            samples = [cleaned_html_doc]
        raw_inp = "Objective: " + obj + "\n" + \
                  "URL: " + url + "\n"

        # Build the ground-truth step description in the model's format.
        if action == "CLICK" or action == "SELECT":
            action_new = "mouse_click_action"
            # action_repr looks like "[tag]  text -> OP"; take the text part.
            sidx = re.search("]  ", action_repr).end()
            eidx = re.search(" ->", action_repr).start()
            target_text = action_repr[sidx:eidx].strip()
            if len(target_text) > 0:
                verb = "Click \"" if action == "CLICK" else "Select \""
                des = verb + target_text + "\""
            else:
                des = "Click here"
        else:
            action_new = "keyboard_sequence_action"
            des = "T" + des[1:]

        selected_eid_rep = print_without_children(selected_element)
        selected_eid_rep = selected_eid_rep[:selected_eid_rep.find(">")+1]
        curstep = str(actualsid) + ".\nDescription: " + des + "\nAction: " + action_new + "\nNode: " + str(gtid) + (" "+str(gtid))*4 + "\nTarget: " + str(selected_eid_rep) + "\n"

        try:
            maxid, pred = eval(model, tokenizer, raw_inp, prevact_gt, samples)
            if pred is None:
                print("[Too Long]")
                num_invalid += 1
                continue
            pred_des = pred[re.search("Description: ", pred).end():re.search("\nAction:", pred).start()]

            maxid_backend = cleaned_nmap[str(maxid)]

            # Correct if the predicted node is a ground-truth candidate, the
            # description matches, or the predicted node lies inside the
            # ground-truth element. The last test matches the full attribute
            # (with closing quote); the original prefix test let id 1 match
            # node="12" and over-counted hits.
            acc = int(
                maxid_backend in fneids
                or pred_des.lower() == des.lower()
                or f'node="{maxid}"' in lxml_html.tostring(selected_element, pretty_print=True, encoding=str)
            )

            prevact_pred += pred

            accs.append(acc)
            print("[GT]", curstep)

            if actualsid not in stepwise_acc:
                stepwise_acc[actualsid] = [0, 0]
            stepwise_acc[actualsid][1] += 1
            stepwise_acc[actualsid][0] += acc
            all_correct *= acc
            print("[ACC]", acc, np.mean(accs), len(accs), np.mean(allaccs), len(allaccs), stepwise_acc)

            response[wid][actualsid] = {"pred": pred, "label_id": tuple(fneids), "label": curstep, "pred_id": str(maxid_backend)}

            prevact_gt += curstep
            actualsid += 1
        except Exception as err:
            # Best-effort: a malformed generation must not abort the whole
            # run, but report it instead of hiding it. KeyboardInterrupt is
            # no longer swallowed.
            print("[ERROR]", repr(err))
            prevact_gt += curstep
            actualsid += 1
    allaccs.append(all_correct)

print("num invalid", num_invalid)
print(np.mean(accs), len(accs), np.mean(allaccs), len(allaccs), stepwise_acc)
with open(filename + "_response.json", "w") as f:
    json.dump(response, f)

